import torch
import random
from typing import List, Any
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
import ray
import asyncio
from coati.experience_maker.base import Experience
from coati.replay_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.replay_buffer import ReplayBuffer
from threading import Lock
import copy

class DetachedReplayBuffer:
    '''
        Detached replay buffer. Shares Experience across workers on the same node,
        so a trainer node is expected to have only one instance.
        It is ExperienceMakerHolder's duty to call the append(exp) method remotely.

    Args:
        sample_batch_size: Batch size when sampling. Experience is not enqueued until enough
            items have been collected to form a full batch. Must be > 0.
        tp_world_size: Number of workers in the same tp group.
        limit: Limit of the number of experience sample batches held in the queue.
            A number <= 0 means unlimited. Defaults to 0.
        cpu_offload: Whether to move incoming experience to CPU in append() before buffering.
            Defaults to True.
    '''

    def __init__(self, sample_batch_size: int, tp_world_size: int = 1, limit : int = 0, cpu_offload: bool = True) -> None:
        # A non-positive batch size would make append() loop forever.
        if sample_batch_size <= 0:
            raise ValueError(f"sample_batch_size must be positive, got {sample_batch_size}")
        self.cpu_offload = cpu_offload
        self.sample_batch_size = sample_batch_size
        self.limit = limit
        self.items = self._new_queue()
        # Single BufferItems waiting until sample_batch_size of them can be merged into one batch.
        self.batch_collector : List[BufferItem] = []

        # Workers in the same tp group share this buffer and need the same sample for one step.
        # Therefore a held_sample must be handed out tp_world_size times before it can be dropped;
        # worker_state[i] records whether worker i has already received the current held_sample.
        self.tp_world_size = tp_world_size
        self.worker_state = [False] * self.tp_world_size
        self.held_sample = None
        self._worker_state_lock = Lock()

    def _new_queue(self) -> Queue:
        '''Create the backing ray Queue; used by both __init__ and clear() so they stay consistent.'''
        return Queue(self.limit, actor_options={"num_cpus": 1})

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        '''
        Split `experience` into items and enqueue them in batches of `sample_batch_size`.

        Expected to be called remotely. Items that do not yet fill a whole batch remain in
        `batch_collector` until later calls complete the batch. `put(block=True)` waits
        while the queue is at its limit.
        '''
        if self.cpu_offload:
            experience.to_device(torch.device('cpu'))
        self.batch_collector.extend(split_experience_batch(experience))
        while len(self.batch_collector) >= self.sample_batch_size:
            batch = make_experience_batch(self.batch_collector[:self.sample_batch_size])
            self.items.put(batch, block=True)
            # Trim only after a successful put so items are not lost if put() raises.
            self.batch_collector = self.batch_collector[self.sample_batch_size:]

    def clear(self) -> None:
        '''
        Discard all buffered experience: shut down the queue actor, start a fresh one with the
        same options, and reset partially collected items and the tp-group sampling state.
        '''
        self.items.shutdown()
        self.items = self._new_queue()
        self.worker_state = [False] * self.tp_world_size
        self.held_sample = None
        self.batch_collector = []

    @torch.no_grad()
    def sample(self, worker_rank = 0, to_device = "cpu") -> Experience:
        '''
        Return the current shared sample for `worker_rank`, moved to `to_device`.

        The first caller of a round pops a new batch from the queue (blocking until one is
        available) while holding the lock. Every worker of the tp group receives that same
        sample: callers that do not complete the round get a deep copy, and the caller that
        completes it gets the held object itself, after which the round state is reset.
        '''
        # `with` guarantees the lock is released even if the queue get() or indexing raises;
        # otherwise every subsequent sample() call would deadlock.
        with self._worker_state_lock:
            if not any(self.worker_state):
                self.held_sample = self._sample_and_erase()
            self.worker_state[worker_rank] = True
            if all(self.worker_state):
                self.worker_state = [False] * self.tp_world_size
                ret = self.held_sample
                # Every worker now has its sample; drop our reference so it can be freed.
                self.held_sample = None
            else:
                ret = copy.deepcopy(self.held_sample)
        ret.to_device(to_device)
        return ret

    @torch.no_grad()
    def _sample_and_erase(self) -> Experience:
        '''Pop the next batch from the queue, blocking until one is available.'''
        return self.items.get(block=True)

    def get_length(self) -> int:
        '''Return the number of full batches currently in the queue (excludes batch_collector).'''
        return self.items.qsize()